package org.zjvis.datascience.service;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.maxmind.geoip2.DatabaseReader;
import com.maxmind.geoip2.model.CityResponse;
import com.maxmind.geoip2.record.Location;
import com.mayabot.nlp.common.Maps;
import com.mayabot.nlp.fasttext.FastText;
import com.mayabot.nlp.fasttext.ScoreLabelPair;
import net.sourceforge.pinyin4j.PinyinHelper;
import net.sourceforge.pinyin4j.format.HanyuPinyinOutputFormat;
import net.sourceforge.pinyin4j.format.HanyuPinyinToneType;
import net.sourceforge.pinyin4j.format.exception.BadHanyuPinyinOutputFormatCombination;
import org.apache.commons.compress.utils.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.util.CollectionUtil;
import org.zjvis.datascience.common.util.DatasetUtil;

import javax.annotation.PostConstruct;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field;
import java.net.InetAddress;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap.SimpleEntry;
import java.util.*;

/**
 * @description UrbanData geographic-information service. Loads the country / province / city
 * dictionaries from classpath JSON files and a GeoLite2 city database from MinIO at start-up,
 * then serves semantic matching, replacement-value recommendation and coordinate lookups.
 * @date 2021-11-29
 */
@Service
public class UrbanDataService {

    private static final Logger logger = LoggerFactory.getLogger("UrbanDataService");

    /** Number of nearest neighbours requested from the FastText model for city recommendations. */
    private static final int NEIGHBOR_COUNT = 10;

    /** Neighbours whose score is not strictly greater than this value are discarded. */
    private static final double MIN_NEIGHBOR_SCORE = 0.5;

    /** Maximum number of dictionary suggestions returned before the pinyin form is appended. */
    private static final int MAX_SUGGESTIONS = 2;

    /** Name / alias -> dictionary entry, populated from {@code dictionary/country.json}. */
    private Map<String, JSONObject> country;
    /** Name / alias -> dictionary entry, populated from {@code dictionary/province.json}. */
    private Map<String, JSONObject> province;
    /** Name / alias -> dictionary entry, populated from {@code dictionary/city.json}. */
    private Map<String, JSONObject> city;
    private DatabaseReader reader;
    private FastText model;
    private HanyuPinyinOutputFormat pinyinFormat;

    @Autowired
    private MinioService minioService;

    @Autowired
    public FastTextService fastTextService;

    @Autowired
    public SemanticService semanticService;

    /**
     * Initialises the dictionaries, the FastText model handle, the pinyin output format and the
     * GeoIP database reader once the bean's dependencies have been injected.
     *
     * @throws Exception if the model or the GeoIP database cannot be obtained or opened
     */
    @PostConstruct
    public void initDictionary() throws Exception {
        country = new HashMap<>();
        province = new HashMap<>();
        city = new HashMap<>();
        readJson("city");
        readJson("country");
        readJson("province");

        model = fastTextService.getAutojoinModel();
        pinyinFormat = new HanyuPinyinOutputFormat();
        pinyinFormat.setToneType(HanyuPinyinToneType.WITHOUT_TONE);

        // The stream is only needed while the reader is being built; close it afterwards so the
        // underlying MinIO connection is not leaked.
        try (InputStream is = minioService.getObject("vis-platform", "GeoLite2-City.mmdb")) {
            reader = new DatabaseReader.Builder(is).build();
        }
    }

    /**
     * @return the GeoIP database reader created by {@link #initDictionary()}
     */
    public DatabaseReader getIpDatabase() {
        return reader;
    }

    /**
     * Resolves a dictionary name to the corresponding in-memory map.
     *
     * @param name one of {@code "country"}, {@code "province"} or {@code "city"}
     * @return the matching dictionary, or {@code null} when the name is not a known dictionary
     */
    private Map<String, JSONObject> dictionaryFor(String name) {
        if ("country".equals(name)) {
            return country;
        } else if ("province".equals(name)) {
            return province;
        } else if ("city".equals(name)) {
            return city;
        }
        return null;
    }

    /**
     * Loads {@code dictionary/<file>.json} from the classpath into the dictionary of the same
     * name. Every string value (and every element of every array value) of an entry, except the
     * {@code id} and {@code center} fields, becomes a lookup key pointing at that entry.
     * Failures are logged and otherwise ignored so that one broken dictionary does not prevent
     * the service from starting.
     *
     * @param file dictionary name: {@code "country"}, {@code "province"} or {@code "city"}
     */
    private void readJson(String file) {
        Map<String, JSONObject> dictionary = dictionaryFor(file);
        if (dictionary == null) {
            logger.error("init Dictionary fail! errorMsg:{}", "unknown dictionary '" + file + "'");
            return;
        }
        String url = String.format("dictionary/%s.json", file);
        try (InputStream fis = this.getClass().getClassLoader().getResourceAsStream(url)) {
            if (fis == null) {
                logger.error("init Dictionary fail! errorMsg:{}", "resource not found: " + url);
                return;
            }
            // Decode explicitly as UTF-8: the dictionaries contain non-ASCII names and must not
            // depend on the platform default charset.
            BufferedReader in =
                new BufferedReader(new InputStreamReader(fis, StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder();
            String line;
            while ((line = in.readLine()) != null) {
                sb.append(line.trim());
            }
            JSONArray objs = JSONArray.parseArray(sb.toString());
            if (objs == null || objs.isEmpty()) {
                logger.warn("init Dictionary: {} contains no entries", url);
                return;
            }
            // The set of indexed fields is taken from the first entry only.
            List<String> keys = new ArrayList<>(objs.getJSONObject(0).keySet());
            keys.remove("id");
            keys.remove("center");
            for (int i = 0; i < objs.size(); i++) {
                JSONObject item = objs.getJSONObject(i);
                for (String key : keys) {
                    Object itemValue = item.getObject(key, Object.class);
                    if (itemValue instanceof String) {
                        dictionary.put((String) itemValue, item);
                    } else if (itemValue instanceof JSONArray) {
                        JSONArray aliases = (JSONArray) itemValue;
                        for (int j = 0; j < aliases.size(); j++) {
                            dictionary.put(aliases.getString(j), item);
                        }
                    }
                }
            }
        } catch (Exception e) {
            // Best effort: keep the service usable, but record the full stack trace.
            logger.error("init Dictionary fail! errorMsg:{}", e.getMessage(), e);
        }
    }

    /**
     * Checks whether a value belongs to the vocabulary of the given semantic.
     *
     * @param value    the value to check; it is normalised with
     *                 {@code DatasetUtil.preprocessStr} before the lookup
     * @param semantic {@code "country"}, {@code "province"} or {@code "city"}
     * @return {@code true} if the normalised value is in the semantic's list; {@code false}
     *         otherwise, including for any other semantic
     */
    public boolean ifMatching(String value, String semantic) {
        value = DatasetUtil.preprocessStr(value);
        if ("country".equals(semantic)) {
            return semanticService.getCountryList().contains(value);
        } else if ("province".equals(semantic)) {
            return semanticService.getProvinceList().contains(value);
        } else if ("city".equals(semantic)) {
            return semanticService.getCityList().contains(value);
        } else {
            return false;
        }
    }

    /**
     * Recommends replacement values for a value that does not match its semantic.
     * <p>
     * For {@code "city"} the candidates are the model's nearest neighbours that score above
     * {@link #MIN_NEIGHBOR_SCORE} and are present in the city list; for the other semantics
     * every list entry is scored against the value. The best {@link #MAX_SUGGESTIONS}
     * candidates (excluding the value itself) are returned, followed by the toneless pinyin
     * transcription of the value.
     *
     * @param value    the non-matching value selected by the user
     * @param semantic semantic of the field: {@code "country"}, {@code "province"} or
     *                 {@code "city"}
     * @return the recommendations, or {@code null} when the semantic is not supported
     * @throws BadHanyuPinyinOutputFormatCombination if the pinyin conversion rejects the format
     */
    public List<String> recommendValue(String value, String semantic)
        throws BadHanyuPinyinOutputFormatCombination {
        Set<String> dict;
        if ("country".equals(semantic)) {
            dict = semanticService.getCountryList();
        } else if ("province".equals(semantic)) {
            dict = semanticService.getProvinceList();
        } else if ("city".equals(semantic)) {
            dict = semanticService.getCityList();
        } else {
            // Kept as null (not an empty list) for compatibility with existing callers.
            return null;
        }

        List<SimpleEntry<Double, String>> candidates = new ArrayList<>();
        if ("city".equals(semantic)) {
            List<ScoreLabelPair> neighbors = model.nearestNeighbor(value, NEIGHBOR_COUNT);
            for (ScoreLabelPair neighbor : neighbors) {
                if (neighbor.getScore() > MIN_NEIGHBOR_SCORE
                    && dict.contains(neighbor.getLabel())) {
                    candidates.add(new SimpleEntry<Double, String>(
                        (double) neighbor.getScore(), neighbor.getLabel()));
                }
            }
        } else {
            for (String key : dict) {
                candidates.add(calculateScore(value, key));
            }
        }
        candidates.sort(new ScoreComparator());

        List<String> result = new ArrayList<>();
        for (SimpleEntry<Double, String> candidate : candidates) {
            if (value.equals(candidate.getValue())) {
                continue;
            }
            result.add(candidate.getValue());
            if (result.size() >= MAX_SUGGESTIONS) {
                break;
            }
        }
        result.add(PinyinHelper.toHanYuPinyinString(value, pinyinFormat, "", false));
        return result;
    }

    /**
     * Scores a dictionary entry against a value using the FastText service.
     *
     * @param value the value being compared
     * @param key   one dictionary entry
     * @return an entry holding the score as key and the dictionary entry as value
     */
    private SimpleEntry<Double, String> calculateScore(String value,
        String key) {
        double score = fastTextService.getFasttext(model, value, key);
        return new SimpleEntry<>(score, key);
    }

    /**
     * Orders score entries by score, highest first.
     */
    public static class ScoreComparator implements
        Comparator<SimpleEntry<Double, String>> {

        @Override
        public int compare(SimpleEntry<Double, String> o1,
            SimpleEntry<Double, String> o2) {
            return o2.getKey().compareTo(o1.getKey());
        }
    }

    /**
     * Resolves a dataset value to a {@code "longitude,latitude"} string.
     * <p>
     * The {@code "ip"} semantic is resolved through the GeoIP database; {@code "postcode"} is
     * looked up in the city dictionary; every other semantic is looked up in the dictionary of
     * the same name.
     *
     * @param semantic semantic type of the value
     * @param key      the dataset value
     * @return the coordinate, or an empty string when it cannot be determined
     */
    public String getCoordinate(String semantic, String key) {
        if ("ip".equals(semantic)) {
            // InetAddress.getByName maps null/empty input to the loopback address, which has no
            // location; short-circuit instead of querying the database.
            if (key == null || key.isEmpty()) {
                return "";
            }
            try {
                // NOTE(review): getByName performs a DNS lookup when key is a host name rather
                // than a literal address - confirm that callers only pass IP literals.
                InetAddress ipAddress = InetAddress.getByName(key);
                CityResponse response = reader.city(ipAddress);
                Location location = response.getLocation();
                Double longitude = location.getLongitude();
                Double latitude = location.getLatitude();
                if (longitude == null || latitude == null) {
                    // Avoid returning the literal string "null,null".
                    return "";
                }
                return longitude + "," + latitude;
            } catch (Exception e) {
                logger.debug("UrbanDataService.getCoordinate() ip lookup failed, key={}, errorMsg={}",
                    key, e.getMessage());
                return "";
            }
        }
        // Null-safe comparison, consistent with the "ip" check above.
        String dictionaryName = "postcode".equals(semantic) ? "city" : semantic;
        return getCoordinateModule(dictionaryName, key);
    }

    /**
     * Looks up the {@code center} coordinate of a dictionary entry.
     *
     * @param semantic dictionary name: {@code "country"}, {@code "province"} or {@code "city"}
     * @param key      the dataset value to look up
     * @return {@code "<center[0]>,<center[1]>"}, or an empty string when the dictionary is
     *         unknown, the key is absent, or the entry has no usable {@code center}
     */
    public String getCoordinateModule(String semantic, String key) {
        Map<String, JSONObject> dictionary = dictionaryFor(semantic);
        if (dictionary == null) {
            logger.error("UrbanDataService.getCoordinate() errorMsg={}",
                "unsupported semantic '" + semantic + "'");
            return "";
        }
        JSONObject entry = dictionary.get(key);
        if (entry == null) {
            // A missing key is an expected outcome, not an error.
            return "";
        }
        try {
            JSONArray center = entry.getJSONArray("center");
            if (center == null || center.size() < 2) {
                return "";
            }
            return center.getString(0) + "," + center.getString(1);
        } catch (Exception e) {
            logger.error("UrbanDataService.getCoordinate() errorMsg={}", e.getMessage(), e);
            return "";
        }
    }
}
